/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.tools.profilers.cpu.systemtrace

import com.android.tools.adtui.model.SeriesData
import com.android.tools.profiler.perfetto.proto.TraceProcessor
import com.android.tools.profilers.cpu.ThreadState
import com.android.tools.profilers.cpu.config.ProfilingConfiguration.TraceType
import perfetto.protos.PerfettoTrace
import java.io.Serializable
import java.lang.Long.max
import java.util.SortedMap
import java.util.TreeMap

/**
 * SystemTraceModelAdapter exposes a common API for accessing the raw model data from system trace
 * captures.
 *
 * This should be used in order to compute data series and nodes that we will display in the UI.
 */
interface SystemTraceModelAdapter {

  /** @return the timestamp at which the capture starts. */
  fun getCaptureStartTimestampUs(): Long
  /** @return the timestamp at which the capture ends. */
  fun getCaptureEndTimestampUs(): Long

  /** @return the process with the given [id], or null if the adapter has no such process. */
  fun getProcessById(id: Int): ProcessModel?
  /** @return all processes known to this adapter. */
  fun getProcesses(): List<ProcessModel>

  /**
   * @return a ThreadModel if we have information for a possible dangling thread with that thread id,
   * i.e. a thread for which we do not know the process it belongs to; null otherwise.
   */
  fun getDanglingThread(tid: Int): ThreadModel?

  /** @return per-core data (scheduling events and counters, see [CpuCoreModel]). */
  fun getCpuCores(): List<CpuCoreModel>

  /** @return the [TraceType] describing the technology used to record this capture. */
  fun getSystemTraceTechnology(): TraceType

  /** @return the power rail counters of the capture (presumably empty when none were recorded). */
  fun getPowerRails(): List<CounterModel>

  /** @return the battery drain counters of the capture (presumably empty when none were recorded). */
  fun getBatteryDrain(): List<CounterModel>

  /**
   * @return true if there is potentially missing data from the capture.
   * It's hard to guarantee if data is missing or not, so this is a best guess.
   */
  fun isCapturePossibleCorrupted(): Boolean

  /**
   * @return Android frame events by layer. Supported since Android R.
   */
  fun getAndroidFrameLayers(): List<TraceProcessor.AndroidFrameEventsResult.Layer>

  /**
   * @return Android FrameTimeline events for jank detection. Supported since Android S.
   */
  fun getAndroidFrameTimelineEvents(): List<AndroidFrameTimelineEvent>
}

/**
 * A process from the capture, together with its threads and process-scoped counters.
 *
 * @property id the process id. The thread whose id equals this value is treated as the main thread.
 * @property name the process name as reported by the trace. It may be blank or start with '<' when the real name is
 *   unknown (see [getSafeProcessName]).
 * @property threadById the threads of this process, keyed by thread id.
 * @property counterByName the counters of this process, keyed by counter name.
 */
data class ProcessModel(
  val id: Int,
  val name: String,
  val threadById: Map<Int, ThreadModel>,
  val counterByName: Map<String, CounterModel>
) : Serializable {

  /** @return the thread whose id equals the process [id], or null if no such thread was recorded. */
  fun getMainThread(): ThreadModel? = threadById[id]

  /** @return all threads of this process. */
  fun getThreads(): Collection<ThreadModel> = threadById.values

  /**
   * Picks the most useful display name for this process, in order of preference:
   * 1. [name], when it is non-blank and does not start with '<';
   * 2. the main thread's name, when the main thread is known and its name is non-blank;
   * 3. the placeholder "<PID>".
   */
  fun getSafeProcessName(): String {
    val ownNameIsUsable = name.isNotBlank() && !name.startsWith("<")
    if (ownNameIsUsable) return name

    return getMainThread()?.name?.takeIf { it.isNotBlank() } ?: "<$id>"
  }

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = -7303268805551316288L
  }
}

/**
 * A thread from the capture, together with the events recorded for it.
 *
 * @property id the thread id.
 * @property tgid presumably the id of the process (thread group) the thread belongs to — TODO confirm.
 * @property name the thread name.
 * @property traceEvents trace events recorded on this thread; each may carry nested [TraceEventModel.childrenEvents].
 * @property schedulingEvents scheduling events for this thread.
 * @property threadStateEvents thread-state intervals for this thread.
 */
data class ThreadModel(
  val id: Int,
  val tgid: Int,
  val name: String,
  val traceEvents: List<TraceEventModel>,
  val schedulingEvents: List<SchedulingEventModel>,
  val threadStateEvents: List<ThreadStateModel>): Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = -172242225241232844L
  }
}

/**
 * A trace event spanning [startTimestampUs] to [endTimestampUs], with the CPU time spent in it and the events nested
 * inside it.
 *
 * @property name the event name.
 * @property startTimestampUs the timestamp at which the event starts.
 * @property endTimestampUs the timestamp at which the event ends.
 * @property cpuTimeUs the CPU time attributed to the event.
 * @property childrenEvents events nested directly within this one (each may have children of its own).
 */
data class TraceEventModel (
  val name: String,
  val startTimestampUs: Long,
  val endTimestampUs: Long,
  val cpuTimeUs: Long,
  val childrenEvents: List<TraceEventModel>): Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = -2496458872098350643L
  }
}

/**
 * A scheduling event: thread [threadId] of process [processId] in [state], presumably on CPU [core], from
 * [startTimestampUs] to [endTimestampUs].
 *
 * @property state the thread state during the event.
 * @property startTimestampUs the timestamp at which the event starts.
 * @property endTimestampUs the timestamp at which the event ends.
 * @property durationUs the event duration. It is stored explicitly and not derived here from the timestamps.
 * @property cpuTimeUs the CPU time attributed to the event.
 * @property processId the id of the process owning the thread.
 * @property threadId the id of the scheduled thread.
 * @property core the index of the CPU core the event refers to (presumably matches [CpuCoreModel.id]).
 */
data class SchedulingEventModel(
  val state: ThreadState,
  val startTimestampUs: Long,
  val endTimestampUs: Long,
  val durationUs: Long,
  val cpuTimeUs: Long,
  val processId: Int,
  val threadId: Int,
  val core: Int): Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = -2824580527083393668L
  }
}

/**
 * A named counter: sampled values keyed by timestamp.
 *
 * [CounterDataUtils.aggregateCounters] treats a counter as a step function, where each value holds until the next
 * sample.
 *
 * @property name the counter name.
 * @property valuesByTimestampUs sample values keyed by their timestamp, in ascending timestamp order.
 */
data class CounterModel(
  val name: String,
  val valuesByTimestampUs: SortedMap<Long, Double>): Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = -305070684614208809L
  }
}

/**
 * An interval from [startTimestampUs] to [endTimestampUs] during which a thread was in [state].
 */
data class ThreadStateModel(
  val state: ThreadState,
  val startTimestampUs: Long,
  val endTimestampUs: Long): Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = 4037744707206656133L
  }
}

/**
 * Helpers that combine and transform [CounterModel] data into [SeriesData] lists.
 */
object CounterDataUtils {

  /**
   * Aggregates counters into named groups.
   *
   * Each counter is treated as a step function: a sampled value holds until the counter's next sample. The counters
   * of a group are summed. The aggregated data has a sample at every timestamp at which any member counter has a
   * sample. At such a timestamp, each member contributes its latest value at or before that timestamp, or 0.0 if it
   * has no sample yet.
   *
   * @param counters the counters to aggregate. Counters without any sample are ignored.
   * @param groupingMap maps a counter name to the name of its group. Counters whose name is not a key of this map are
   *   excluded from the result.
   * @param normalizeStartTime when true, each group's data is clipped to start at the latest first-sample timestamp
   *   among its member counters. From that point on, every member has reported at least one value.
   * @return a map from group name to the aggregated series, sorted by group name.
   */
  @JvmStatic
  fun aggregateCounters(counters: List<CounterModel>,
                        groupingMap: Map<String, String>,
                        normalizeStartTime: Boolean): SortedMap<String, List<SeriesData<Long>>> {
    // Group name -> aggregated data (timestamp -> summed value), sorted by timestamp.
    val groupToAggregatedData = mutableMapOf<String, TreeMap<Long, Double>>()
    // Group name -> normalized start time, i.e. the maximum first-sample timestamp of the group's counters.
    val groupNameToNormalizedStartTime = mutableMapOf<String, Long>()

    for ((counterName, counterData) in convertCountersToAggregatableFormat(counters)) {
      // If the counter does not belong to a group, then it is not included in the output.
      val groupName = groupingMap[counterName] ?: continue
      // An empty counter contributes nothing and has no first timestamp (firstKey() would throw), so skip it.
      if (counterData.isEmpty()) continue

      groupNameToNormalizedStartTime[groupName] =
        max(groupNameToNormalizedStartTime[groupName] ?: Long.MIN_VALUE, counterData.firstKey())

      val existingData = groupToAggregatedData[groupName]
      groupToAggregatedData[groupName] =
        if (existingData == null) counterData else sumStepCounters(existingData, counterData)
    }

    if (normalizeStartTime) {
      removeDataBeforeNormalizedStartTime(groupToAggregatedData, groupNameToNormalizedStartTime)
    }

    return groupToAggregatedData
      .mapValues { (groupName, data) -> convertCounterToSeriesData(CounterModel(groupName, data)) }
      .toSortedMap()
  }

  /**
   * Sums two step-function counters.
   *
   * The result has a sample at every timestamp present in either input. Its value there is the latest value of
   * [first] at or before that timestamp plus the latest value of [second] at or before that timestamp. A counter
   * contributes 0.0 before its first sample.
   */
  private fun sumStepCounters(first: TreeMap<Long, Double>, second: TreeMap<Long, Double>): TreeMap<Long, Double> {
    val summed = TreeMap<Long, Double>()
    for (ts in first.keys + second.keys) {
      summed[ts] = (first.floorEntry(ts)?.value ?: 0.0) + (second.floorEntry(ts)?.value ?: 0.0)
    }
    return summed
  }

  /**
   * In order to perform the aggregation in 'aggregateCounters', the 'valuesByTimestampUs' field
   * of the CounterModel must be of type TreeMap<Long, Double>, rather than SortedMap<Long, Double>.
   * This helper function does this type conversion to guarantee we have data in aggregatable format.
   * Each TreeMap is a fresh copy, so later modifications never touch the input counters.
   * If several counters share a name, the last one in [counters] wins.
   *
   * @param counters - list of CounterModel's to be transformed
   * @return a mapping from the counter name (String) to a map of the counter data (TreeMap). This TreeMap maps a
   * timestamp (Long) to a corresponding value (Double), and is sorted by timestamp (the key) in increasing order.
   */
  private fun convertCountersToAggregatableFormat(counters: List<CounterModel>): Map<String, TreeMap<Long, Double>> {
    return counters.associate {
      it.name to TreeMap<Long, Double>(it.valuesByTimestampUs)
    }
  }

  /**
   * Clips each group's data so that it starts at the group's normalized start time (inclusive). This is useful when
   * combining counters of very different magnitudes, as it removes the first points of the combined series, where
   * not every counter has reported yet and values differ greatly.
   *
   * `tailMap` is used instead of `subMap(from, lastKey() + 1)`, which overflowed when the last timestamp was
   * Long.MAX_VALUE.
   */
  private fun removeDataBeforeNormalizedStartTime(groupToAggregatedData: MutableMap<String, TreeMap<Long, Double>>,
                                                  groupNameToMaxMinTs: Map<String, Long>) {
    for (entry in groupToAggregatedData.entries) {
      val maxMinTs = groupNameToMaxMinTs[entry.key] ?: continue
      entry.setValue(TreeMap(entry.value.tailMap(maxMinTs)))
    }
  }

  /**
   * Converts each series into a delta series. Every pair of consecutive points (previous, current) yields one point
   * at current.x whose value is current.value - previous.value. A series with fewer than two points therefore maps
   * to an empty list.
   */
  @JvmStatic
  fun convertSeriesDataToDeltaSeries(countersData: SortedMap<String, List<SeriesData<Long>>>): SortedMap<String, List<SeriesData<Long>>> {
    val groupToDeltaSeries = sortedMapOf<String, List<SeriesData<Long>>>()

    for ((groupName, data) in countersData) {
      groupToDeltaSeries[groupName] = data.zipWithNext { previous, current ->
        SeriesData(current.x, current.value - previous.value)
      }
    }

    return groupToDeltaSeries
  }

  /**
   * Converts a counter into one [SeriesData] per sample, in timestamp order. Values are truncated toward zero by
   * Double.toLong().
   */
  @JvmStatic
  fun convertCounterToSeriesData(counter: CounterModel): List<SeriesData<Long>> {
    return counter.valuesByTimestampUs.map { (timestampUs, value) -> SeriesData(timestampUs, value.toLong()) }
  }
}

/**
 * Data recorded for a single CPU core.
 *
 * @property id the core index.
 * @property schedulingEvents scheduling events on this core.
 * @property countersMap counters for this core, presumably keyed by [CounterModel.name] — TODO confirm.
 */
data class CpuCoreModel(
  val id: Int,
  val schedulingEvents: List<SchedulingEventModel>,
  val countersMap: Map<String, CounterModel>) : Serializable {

  companion object {
    // Generated by serialver; keep it stable so previously serialized instances remain readable.
    @JvmStatic
    val serialVersionUID = 8233672032802842718L
  }
}

/**
 * A single Android FrameTimeline event, as exposed by [SystemTraceModelAdapter.getAndroidFrameTimelineEvents].
 *
 * @param appJankType Raw data may contain multiple jank types, but only the app jank type is extracted, namely
 *                    "App Deadline Missed", "Buffer Stuffing" and "Unknown Jank". If no jank is present, the value is
 *                    "None". If no app jank is present, the value is "Unspecified".
 */
data class AndroidFrameTimelineEvent(
  val displayFrameToken: Long,
  val surfaceFrameToken: Long,
  val expectedStartUs: Long,
  val expectedEndUs: Long,
  val actualEndUs: Long,
  val layerName: String,
  val presentType: PerfettoTrace.FrameTimelineEvent.PresentType,
  val appJankType: PerfettoTrace.FrameTimelineEvent.JankType,
  val onTimeFinish: Boolean,
  val gpuComposition: Boolean,
  val layoutDepth: Int
) : Serializable {

  /** Length of the expected window: [expectedEndUs] minus [expectedStartUs]. */
  val expectedDurationUs: Long
    get() = expectedEndUs - expectedStartUs

  /** Time from [expectedStartUs] to [actualEndUs]; the event carries no separate actual start. */
  val actualDurationUs: Long
    get() = actualEndUs - expectedStartUs

  /** True for every jank type except JANK_NONE (so also for the "Unspecified" value). */
  val isJank: Boolean
    get() = appJankType != PerfettoTrace.FrameTimelineEvent.JankType.JANK_NONE

  /** True only when the app missed its deadline (JANK_APP_DEADLINE_MISSED). */
  val isActionableJank: Boolean
    get() = appJankType == PerfettoTrace.FrameTimelineEvent.JankType.JANK_APP_DEADLINE_MISSED
}

/**
 * Turns a sequence of events into a gap-free series.
 *
 * Each event becomes a [SeriesData] at its start time. Whenever the previous event's end (initially 0) lies before
 * the next event's start, a padding entry covering that gap is inserted first. After the last event, a trailing
 * padding entry from the last end to [Long.MAX_VALUE] is appended, unless that end already equals [Long.MAX_VALUE].
 *
 * NOTE(review): events appear to be expected in start order and non-overlapping; this function neither sorts nor
 * merges them.
 *
 * @param start extracts an event's start time.
 * @param end extracts an event's end time.
 * @param data converts a source event into the target value type.
 * @param pad creates the padding value for a gap, given the gap's start and end.
 */
fun <X, Y> Iterable<X>.padded(start: (X) -> Long, end: (X) -> Long, data: (X) -> Y, pad: (Long, Long) -> Y): List<SeriesData<Y>> {
  val series = mutableListOf<SeriesData<Y>>()
  var previousEnd = 0L
  for (event in this) {
    val eventStart = start(event)
    if (previousEnd < eventStart) {
      // Fill the gap between the previous event and this one.
      series.add(SeriesData(previousEnd, pad(previousEnd, eventStart)))
    }
    previousEnd = end(event)
    series.add(SeriesData(eventStart, data(event)))
  }
  // Close the series so the last event has a proper end.
  if (previousEnd < Long.MAX_VALUE) {
    series.add(SeriesData(previousEnd, pad(previousEnd, Long.MAX_VALUE)))
  }
  return series
}